#%% 
import shapreg  # https://github.com/iancovert/shapley-regression
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import argparse
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because every non-empty string is truthy, so the flag could never be
    switched off explicitly.

    Args:
        value: The raw token from the command line (or an actual bool).

    Returns:
        True or False.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised
            boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # The empty string stays False, as it was with the original ``bool('')``.
    if token in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser()
# Accepts `--test_training_speed`, `--test_training_speed True` and
# `--test_training_speed False`; omitting the flag keeps the default (False).
parser.add_argument('--test_training_speed', type=_str2bool, nargs='?',
                    const=True, default=False,
                    help='print the wall-clock training time of the explainers')
args = parser.parse_args()

#%% 
# Read the raw table; skipinitialspace=True makes pandas skip the whitespace
# that follows each delimiter.
df = pd.read_csv(
    filepath_or_buffer="OnlineNewsPopularity/OnlineNewsPopularity.csv",
    skipinitialspace=True,
)

# Inputs: everything except the non-predictive columns and the target
# (58 features remaining).
df_X = df.drop(columns=['url', 'timedelta', 'shares'])

# Target: the raw number of shares.
sr_Y = df['shares']


#%% preprocessing
'''
Input features:

After dropping non-predictive input features, there are 58 input features remaining.

Data types: All the input features have numerical values, but by counting the unique values, we can tell some of them are actually binary. And some of them are related in content. So we can merge them into multinomial features encoded by integers to reduce dimension.

Feature scales: The scales (ranges) of these features are quite different, so standardization is needed before fitting models on them.

Outliers: If we define values more than 5 standard deviations from the mean as outliers, some features have a considerable number of outliers, so outlier removal is also needed in preprocessing.
'''

def outlierCounts(series, n_std=5):
    """Count the values lying more than ``n_std`` standard deviations from the mean.

    The comparison is strict (``>``) so that this count is the exact
    complement of the outlier-removal step, which keeps every value whose
    deviation is ``<= n_std * std``. With a non-strict comparison a constant
    column (std == 0) would report all of its rows as outliers.

    Args:
        series: pandas Series of numeric values.
        n_std: Distance threshold, in standard deviations (default 5).

    Returns:
        Number of outlying values as a plain ``int``.
    """
    deviation = np.abs(series - series.mean())
    is_outlier = deviation > (n_std * series.std())
    return int(is_outlier.sum())

def uniqueValueCount(series):
    """Return how many distinct values ``series`` holds (a missing value counts as one)."""
    return series.nunique(dropna=False)

# One row per input feature: name and dtype first, then summary statistics.
input_feats = df_X.dtypes.reset_index()
input_feats.columns = ['name', 'dtype']

# Each statistic is a Series indexed by feature name; resetting the index
# aligns it positionally with the rows of input_feats.
per_feature_stats = {
    'mean': df_X.mean(),
    'std': df_X.std(),
    'range': df_X.max() - df_X.min(),
    'unique_values_count': df_X.apply(uniqueValueCount, axis=0),
    'outliers_count': df_X.apply(outlierCounts, axis=0),
}
for stat_name, stat_values in per_feature_stats.items():
    input_feats[stat_name] = stat_values.reset_index(drop=True)


#%% Merge Binary Features
'''
Among those binary features, there are 6 describing the content categories of news and 7 describing the publish weekday. We can merge them and reduce the number of features to 47.
'''
def mergeFeatures(df, old_feats, new_feat):
    """Collapse a group of binary columns into one integer-coded column.

    The dataframe is modified in place: ``new_feat`` is added and every
    column in ``old_feats`` is removed.

    in: dataframe, binaryFeatureNames and multinomialFeatureName
    out: the same dataframe object, after modification

    Encoding: rows where no binary column equals 1 get code 0; otherwise the
    code is the 1-based position of the matching column in ``old_feats``
    (when several match, the last one wins).
    """
    # Start with "no category" everywhere.
    df[new_feat] = 0

    for code, binary_col in enumerate(old_feats, start=1):
        is_member = df[binary_col] == 1.0
        df.loc[is_member, new_feat] = code
        df.drop(columns=binary_col, inplace=True)

    return df

# Binary indicator columns describing the content category of an article.
data_channels = [
    'data_channel_is_lifestyle', 'data_channel_is_entertainment',
    'data_channel_is_bus', 'data_channel_is_socmed',
    'data_channel_is_tech', 'data_channel_is_world',
]
# Binary indicator columns describing the publication weekday.
weekdays = [
    'weekday_is_monday', 'weekday_is_tuesday', 'weekday_is_wednesday',
    'weekday_is_thursday', 'weekday_is_friday', 'weekday_is_saturday',
    'weekday_is_sunday',
]

# mergeFeatures edits df_X in place; the returned object (bound to `df`) is
# that same dataframe.
for binary_columns, merged_name in (
    (data_channels, 'data_channel'),
    (weekdays, 'pub_weekday'),
):
    df = mergeFeatures(df_X, binary_columns, merged_name)


#%% Remove Outliers and Normalize features
import sklearn.preprocessing as prep

# Remove outliers column by column: keep only rows whose value is within
# 5 standard deviations of the column mean. The mean and std are recomputed
# on the rows that survived the previous columns.
for col in df_X.columns:
    column = df_X[col]
    deviation = (column - column.mean()).abs()
    within_range = deviation <= 5 * column.std()
    df_X = df_X.loc[within_range]

# Keep the targets of the surviving rows only.
sr_Y = sr_Y.loc[df_X.index]

def standarize(arr_X):
    """Min-max scale each column of ``arr_X``, then centre every row on its own mean.

    Note that the centring is per row (axis=1), i.e. across the features of
    one sample, not per feature.
    """
    scaled = prep.MinMaxScaler().fit_transform(arr_X)
    row_means = scaled.mean(axis=1, keepdims=True)
    return scaled - row_means

# Convert the cleaned frame to a plain array and rescale it.
arr_X = standarize(df_X.to_numpy())

#%% Binarize Y
'''
As we mentioned before, we use the median value 1400 as the threshold to binarize target feature and divide the data points into 2 classes. Because of the outlier removal, the sizes of 2 classes are not the same anymore. But they are still more or less balanced.
'''
# Label is 1 when the share count exceeds the threshold and 0 otherwise;
# the threshold is the original median of `shares`.
shares_column = sr_Y.values.reshape(-1, 1)
arr_Y = prep.binarize(shares_column, threshold=1400)
sr_Y = pd.Series(arr_Y.ravel())

# Class sizes after outlier removal.
unique_items, counts = np.unique(arr_Y, return_counts=True)
#%% 
# Load and split data
# Hold out 20% of the data for testing, then 20% of what is left for validation.
X_train, X_test, Y_train, Y_test = train_test_split(
    arr_X, arr_Y,
    test_size=0.2,
    random_state=7,
)
X_train, X_val, Y_train, Y_val = train_test_split(
    X_train, Y_train,
    test_size=0.2,
    random_state=0,
)

# Data scaling: statistics come from the training split only and are then
# applied unchanged to the validation and test splits.
num_features = X_train.shape[1]
# feature_names = X_train.tolist()
ss = StandardScaler()
X_train = ss.fit_transform(X_train)
X_val, X_test = (ss.transform(split) for split in (X_val, X_test))

#%% Train Model
import pickle
import os.path
import sys
sys.path.append('..')
from copy import deepcopy
import torch
import torch.nn as nn
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import torch.optim as optim
import lightgbm as lgb
from lightgbm.callback import log_evaluation, early_stopping
# Prefer the GPU, but fall back to the CPU so the script also runs on
# machines without CUDA (a hard-coded 'cuda' device only fails later, at the
# first `.to(device)` call).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#%% 
if not os.path.isfile('news model.pkl'):
    # No cached model on disk: train a LightGBM binary classifier.
    params = {
        "max_bin": 512,
        "learning_rate": 0.05,
        "boosting_type": "gbdt",
        "objective": "binary",
        "metric": "binary_logloss",
        "num_leaves": 10,
        "verbose": -1,
        "min_data": 100,
        "boost_from_average": True
    }

    # Datasets and callbacks: log every 1000 rounds, stop after 50 rounds
    # without improvement on the validation set.
    d_train = lgb.Dataset(X_train, label=Y_train)
    d_val = lgb.Dataset(X_val, label=Y_val)
    callbacks = [
        log_evaluation(period=1000),
        early_stopping(stopping_rounds=50),
    ]

    # Train for at most 10000 boosting rounds.
    model = lgb.train(
        params,
        d_train,
        num_boost_round=10000,
        valid_sets=[d_val],
        callbacks=callbacks,
    )

    # Cache the trained model for later runs.
    with open('news model.pkl', 'wb') as f:
        pickle.dump(model, f)

else:
    # NOTE: pickle executes arbitrary code on load; only load a file that
    # this script wrote itself.
    print('Loading saved model')
    with open('news model.pkl', 'rb') as f:
        model = pickle.load(f)

#%% Train surrogate
import torch
import torch.nn as nn
from fastshap.utils import MaskLayer1d
from fastshap import Surrogate, KLDivLoss

# Select device
# Prefer the GPU, but fall back to the CPU so the script also runs on
# machines without CUDA (a hard-coded 'cuda' device only fails later, at the
# first `.to(device)` call).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#%% 
# Check for model
if os.path.isfile('news surrogate.pt'):
    print('Loading saved surrogate model')
    # The checkpoint is a fully pickled nn.Module (see torch.save below), not
    # a state_dict. PyTorch >= 2.6 defaults to weights_only=True, which
    # refuses to unpickle module classes, so the full unpickler is requested
    # explicitly. Only load a checkpoint this script produced itself.
    surr = torch.load('news surrogate.pt', weights_only=False).to(device)
    surrogate = Surrogate(surr, num_features)

else:
    # Create surrogate model. The first Linear layer takes 2 * num_features
    # inputs -- presumably the masked features with the mask appended by
    # MaskLayer1d(append=True); confirm against fastshap.utils.
    surr = nn.Sequential(
        MaskLayer1d(value=0, append=True),
        nn.Linear(2 * num_features, 128),
        nn.ELU(inplace=True),
        nn.Linear(128, 128),
        nn.ELU(inplace=True),
        nn.Linear(128, 2)).to(device)

    # Set up surrogate object
    surrogate = Surrogate(surr, num_features)

    # Set up original model
    def original_model(x):
        """Return the LightGBM predictions for ``x`` as a two-column tensor.

        ``model.predict`` yields one score per row; the columns are
        ``1 - score`` and ``score``, returned on the same device as ``x``.
        """
        pred = model.predict(x.cpu().numpy())
        pred = np.stack([1 - pred, pred]).T
        return torch.tensor(pred, dtype=torch.float32, device=x.device)

    # Train the surrogate to match the original model's outputs
    surrogate.train_original_model(
        X_train,
        X_val,
        original_model,
        batch_size=64,
        max_epochs=100,
        loss_fn=KLDivLoss(),
        validation_samples=10,
        validation_batch_size=10000,
        verbose=True)

    # Save surrogate (moved to the CPU first, then back to the working device)
    surr.cpu()
    torch.save(surr, 'news surrogate.pt')
    surr.to(device)

#%% Train SimSHAP

from simshap.simshap_sampling import SimSHAPSampling
from models import SimSHAPTabular
import time
# Check for model
if os.path.isfile('news simshap.pt'):
    print('Loading saved explainer model')
    # The checkpoint is a fully pickled module (see torch.save below).
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle
    # module classes, so the full unpickler is requested explicitly. Only
    # load a checkpoint this script produced itself.
    explainer = torch.load('news simshap.pt', weights_only=False).to(device)
    simshap = SimSHAPSampling(explainer=explainer, imputer=surrogate, device=device)

else:
    # Create explainer model
    explainer = SimSHAPTabular(in_dim=num_features, hidden_dim=128, out_dim=2).to(device)

    # Set up SimSHAP object
    simshap = SimSHAPSampling(explainer=explainer, imputer=surrogate, device=device)

    # Train (timed only when --test_training_speed is set); validation uses
    # the first 100 validation rows.
    if args.test_training_speed:
        start = time.time()
    simshap.train(
        X_train,
        X_val[:100],
        batch_size=2048,
        num_samples=64,
        max_epochs=1000,
        lr=1e-3,
        bar=False,
        validation_samples=128,
        verbose=True,
        lookback=10,
        lr_factor=0.5)
    if args.test_training_speed:
        print('simshap training time: ', time.time()-start)

    # Save explainer (moved to the CPU first, then back to the working device)
    explainer.cpu()
    torch.save(explainer, 'news simshap.pt')
    explainer.to(device)

#%% fastshap
from simshap.fastshap_plus import FastSHAP

# Check for model
if os.path.isfile('news fastshap.pt'):
    print('Loading saved explainer model')
    # The checkpoint is a fully pickled nn.Module (see torch.save below).
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle
    # module classes, so the full unpickler is requested explicitly. Only
    # load a checkpoint this script produced itself.
    explainer_fastshap = torch.load(
        'news fastshap.pt', weights_only=False).to(device)
    fastshap = FastSHAP(explainer_fastshap, surrogate, normalization='additive',
                        link=nn.Identity())

else:
    # Create explainer model: an MLP that emits 2 * num_features values per
    # input row.
    explainer_fastshap = nn.Sequential(
        nn.Linear(num_features, 128),
        nn.ReLU(inplace=True),
        nn.Linear(128, 128),
        nn.ReLU(inplace=True),
        nn.Linear(128, 2 * num_features)).to(device)

    # Set up FastSHAP object
    fastshap = FastSHAP(explainer_fastshap, surrogate,
                        link=nn.Identity(), normalization='additive')

    # Train (timed only when --test_training_speed is set); validation uses
    # the first 100 validation rows.
    if args.test_training_speed:
        start = time.time()
    fastshap.train(
        X_train,
        X_val[:100],
        batch_size=32,
        num_samples=32,
        max_epochs=200,
        validation_samples=128,
        verbose=True)
    if args.test_training_speed:
        print('fastshap training time: ', time.time()-start)

    # Save explainer (moved to the CPU first, then back to the working device)
    explainer_fastshap.cpu()
    torch.save(explainer_fastshap, 'news fastshap.pt')
    explainer_fastshap.to(device)


#%% Compare with KernelSHAP
import matplotlib.pyplot as plt
# Setup for KernelSHAP
def imputer(x, S):
    """Evaluate the surrogate on inputs ``x`` with subsets ``S``; return a NumPy array.

    Both arguments are converted to float32 tensors on ``device`` before the
    call, and the prediction is detached and copied back to the CPU.
    """
    inputs = torch.tensor(x, dtype=torch.float32, device=device)
    subsets = torch.tensor(S, dtype=torch.float32, device=device)
    return surrogate(inputs, subsets).detach().cpu().numpy()

# Select example
# Fixed seed so the same test example is picked on every run.
np.random.seed(200)
ind = np.random.choice(len(X_test))
# Keep the batch dimension: x is a one-row slice of X_test.
x = X_test[ind:ind+1]
# The labels come from a column vector (built with reshape(-1, 1)), so
# Y_test[ind] is a length-1 array. Calling int() on an array with ndim > 0 is
# deprecated since NumPy 1.25, so the single element is extracted first.
y = int(np.ravel(Y_test[ind])[0])

# Run evoshap
# Explanations from the two learned explainers for the selected example.
# The SimSHAP result is transposed so that it is indexed the same way as the
# other results in the plot below.
simshap_batch = simshap.shap_values(x)
simshap_values = simshap_batch[0].transpose(1, 0)
fastshap_values = fastshap.shap_values(x)[0]

# Run KernelSHAP to convergence, keeping the intermediate estimates
# (return_all=True) so that an early iterate can be plotted as well.
game = shapreg.games.PredictionGame(imputer, x)
regression_options = dict(
    batch_size=32,
    paired_sampling=False,
    detect_convergence=True,
    bar=True,
    return_all=True,
)
shap_values, all_results = shapreg.shapley.ShapleyRegression(
    game, **regression_options)

# Create figure
plt.figure(figsize=(9, 5.5))

# Bar chart
width = 0.75
# Number of leading features to display. It is clamped to the number of
# available features so the bar positions and heights always have equal
# length (the slices previously hard-coded 10 and ignored this variable).
val_num_features = min(10, shap_values.values.shape[0])
kernelshap_iters = 128
positions = np.arange(val_num_features)

# Intermediate KernelSHAP estimate recorded at `kernelshap_iters` iterations;
# list.index raises ValueError if that iteration count was never recorded.
kernelshap_early = all_results['values'][
    list(all_results['iters']).index(kernelshap_iters)]

# Four bars per feature, a quarter of `width` wide each; column `y` selects
# the values for the example's label.
plt.bar(positions - width / 2, shap_values.values[:val_num_features, y],
        width / 4, label='True SHAP values', color='tab:gray')
plt.bar(positions - width / 4, simshap_values[:val_num_features, y],
        width / 4, label='SimSHAP', color='tab:green')
plt.bar(positions,
        fastshap_values[:val_num_features, y],
        width / 4, label='fastSHAP', color='tab:blue')
plt.bar(positions + width / 4,
        kernelshap_early[:val_num_features, y],
        width / 4, label='KernelSHAP @ {}'.format(kernelshap_iters), color='tab:red')

# Annotations
plt.legend(fontsize=16)
plt.tick_params(labelsize=14)
plt.ylabel('SHAP Values', fontsize=16)
plt.title('News Explanation Example', fontsize=18)

plt.tight_layout()
plt.savefig('news result.png')
plt.show()
